import torch
import torch.nn as nn
import os
from tqdm import tqdm


def train_reward_model(reward_model, dataloader, optimizer, clip_model, device, epoch, scale_factor=5):
    """
    Train the reward model for one epoch with a pairwise ranking loss and
    show per-batch progress with tqdm.

    Each batch yields ``(prompts, positive_imgs, negative_imgs)``. The frozen
    CLIP model embeds the prompt and both images, the reward model scores each
    (image, prompt) pair, and the loss ``-log(sigmoid(scale * (r_pos - r_neg)))``
    pushes the positive image's reward above the negative one's.

    :param reward_model: reward model, called as ``reward_model(image_emb, text_emb)``
    :param dataloader: iterable of ``(prompts, positive_imgs, negative_imgs)`` batches;
        every element must support ``.to(device)`` (prompts are presumably
        already tokenized -- confirm against the dataset)
    :param optimizer: optimizer that updates the reward model's parameters
    :param clip_model: CLIP model exposing ``encode_text`` / ``encode_image``;
        it is only used under ``torch.no_grad()`` and is not trained here
    :param device: device (CPU or GPU) the batches are moved to
    :param epoch: zero-based index of the current epoch (displayed as ``epoch + 1``)
    :param scale_factor: multiplier applied to the reward difference before the
        sigmoid; defaults to the previously hard-coded value 5
    :return: mean batch loss over the epoch, or ``nan`` if the dataloader
        produced no batches
    """
    reward_model.train()
    clip_model.eval()

    running_loss = 0.0  # sum of per-batch mean losses
    num_batches = 0     # counted explicitly so loaders without __len__ also work

    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch + 1}", leave=False)
    for prompts, positive_imgs, negative_imgs in progress_bar:
        prompts = prompts.to(device)
        positive_imgs = positive_imgs.to(device)
        negative_imgs = negative_imgs.to(device)

        with torch.no_grad():
            # CLIP is a fixed feature extractor: no gradients flow into it.
            text_embeddings = clip_model.encode_text(prompts).float()
            positive_embeddings = clip_model.encode_image(positive_imgs).float()
            negative_embeddings = clip_model.encode_image(negative_imgs).float()

        # Forward pass
        positive_reward = reward_model(positive_embeddings, text_embeddings).squeeze()
        negative_reward = reward_model(negative_embeddings, text_embeddings).squeeze()

        # Pairwise ranking loss. logsigmoid is used instead of
        # log(sigmoid(x) + eps): the eps form saturates near 18.4 and squashes
        # the gradient exactly for the pairs that are ranked most wrongly.
        reward_diff = (positive_reward - negative_reward) * scale_factor
        loss = -nn.functional.logsigmoid(reward_diff).mean()

        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Read the scalar once; every .item() call forces a device sync.
        batch_loss = loss.item()
        running_loss += batch_loss
        num_batches += 1

        # Refresh the tqdm postfix with the latest numbers
        progress_bar.set_postfix({
            "Batch Loss": f"{batch_loss:.4f}",
            "Avg Loss": f"{running_loss / num_batches:.4f}"
        })

    if num_batches == 0:
        # Nothing was trained on; NaN cannot be mistaken for a good loss.
        return float("nan")
    return running_loss / num_batches


def validate_model(model, dataloader, clip_model, device, scale_factor=5):
    """
    Evaluate the reward model on a validation set with the same pairwise
    ranking loss used for training, print the mean loss and return it.

    :param model: reward model, called as ``model(image_emb, text_emb)``
    :param dataloader: iterable of ``(prompts, positive_imgs, negative_imgs)`` batches;
        every element must support ``.to(device)``
    :param clip_model: CLIP model exposing ``encode_text`` / ``encode_image``
    :param device: device (CPU or GPU) the batches are moved to
    :param scale_factor: multiplier applied to the reward difference before the
        sigmoid; defaults to the previously hard-coded value 5
    :return: mean batch loss, or ``nan`` if the dataloader produced no batches
    """
    model.eval()
    clip_model.eval()
    total_loss = 0.0
    num_batches = 0  # counted explicitly so loaders without __len__ also work

    # A single no_grad scope covers both the CLIP encoders and the reward model.
    with torch.no_grad():
        for prompts, positive_imgs, negative_imgs in dataloader:
            prompts = prompts.to(device)
            positive_imgs = positive_imgs.to(device)
            negative_imgs = negative_imgs.to(device)

            text_embeddings = clip_model.encode_text(prompts).float()
            positive_embeddings = clip_model.encode_image(positive_imgs).float()
            negative_embeddings = clip_model.encode_image(negative_imgs).float()

            positive_reward = model(positive_embeddings, text_embeddings).squeeze()
            negative_reward = model(negative_embeddings, text_embeddings).squeeze()

            # logsigmoid stays accurate for large negative differences, where
            # log(sigmoid(x) + eps) would cap the reported loss near 18.4.
            reward_diff = (positive_reward - negative_reward) * scale_factor
            loss = -nn.functional.logsigmoid(reward_diff).mean()

            total_loss += loss.item()
            num_batches += 1

    # NaN for an empty loader: it cannot be mistaken for a good score.
    avg_loss = total_loss / num_batches if num_batches else float("nan")
    # NOTE: the message text (including its spelling) is kept unchanged so that
    # anything matching on existing log output keeps working.
    print(f"Validaion Loss: {avg_loss:.4f}")
    return avg_loss


def save_model(model, optimizer, epoch, save_path="./saved_models/reward_model.pth"):
    """
    Save a training checkpoint to ``save_path``.

    The checkpoint is a dict with the keys ``'epoch'``, ``'model_state_dict'``
    and ``'optimizer_state_dict'``.

    :param model: model whose ``state_dict()`` is saved
    :param optimizer: optimizer whose ``state_dict()`` is saved
    :param epoch: current epoch number, stored as-is
    :param save_path: destination file; missing parent directories are created
    """
    # dirname is "" for a bare filename, and os.makedirs("") raises
    # FileNotFoundError, so only create the directory when there is one.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, save_path)
    print(f"Model saved to {save_path}")


def load_checkpoint(model, optimizer, path, map_location="cpu"):
    """
    Restore model (and optionally optimizer) state from a checkpoint file.

    The file is expected to hold a dict with the keys ``'model_state_dict'``,
    ``'optimizer_state_dict'`` and ``'epoch'``.

    :param model: model to load the saved weights into (modified in place)
    :param optimizer: optimizer to restore (modified in place), or ``None`` to
        skip restoring optimizer state, e.g. when loading for inference only
    :param path: path of the checkpoint file
    :param map_location: forwarded to ``torch.load``; the default ``"cpu"``
        lets a checkpoint saved on a GPU be opened on a machine without one
        (``load_state_dict`` then copies the values into the objects passed in)
    :return: tuple ``(model, optimizer, epoch)``
    """
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    return model, optimizer, epoch